{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%HTML\n",
    "<style>\n",
    "    body {\n",
    "        --vscode-font-family: \"Roboto Thin\"\n",
    "    }\n",
    "</style>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center>\n",
    "<h1> SuGaR: Surface-Aligned Gaussian Splatting for Efficient 3D Mesh Reconstruction \n",
    "    <br>and High-Quality Mesh Rendering</h1>\n",
    "Antoine Guédon and Vincent Lepetit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import open3d as o3d\n",
    "from pytorch3d.io import load_objs_as_meshes\n",
    "from pytorch3d.renderer import (\n",
    "    AmbientLights,\n",
    "    RasterizationSettings, \n",
    "    MeshRenderer, \n",
    "    MeshRasterizer,  \n",
    "    SoftPhongShader,\n",
    "    )\n",
    "from pytorch3d.renderer.blending import BlendParams\n",
    "from sugar_scene.gs_model import GaussianSplattingWrapper\n",
    "from sugar_scene.sugar_model import SuGaR, load_refined_model\n",
    "from sugar_utils.spherical_harmonics import SH2RGB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "numGPU = 0\n",
    "torch.cuda.set_device(numGPU)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data and vanilla Gaussian Splatting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ========== Loading parameters ==========\n",
    "use_eval_split = False\n",
    "n_skip_images_for_eval_split = 8\n",
    "\n",
    "iteration_to_load = 7000\n",
    "# iteration_to_load = 30_000\n",
    "\n",
    "load_gt_images = False\n",
    "use_custom_bbox = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Choose a data directory (the directory that contains the images subdirectory)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example\n",
    "source_path = './data/nerfstudio/qant03/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Choose a corresponding vanilla Gaussian Splatting checkpoint directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example\n",
    "gs_checkpoint_path = './gaussian_splatting/output/7bdae844-6/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load data and 3DGS checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ====================Load NeRF model and training data====================\n",
    "\n",
    "# Load Gaussian Splatting checkpoint \n",
    "print(f\"\\nLoading config {gs_checkpoint_path}...\")\n",
    "if use_eval_split:\n",
    "    print(\"Performing train/eval split...\")\n",
    "nerfmodel = GaussianSplattingWrapper(\n",
    "    source_path=source_path,\n",
    "    output_path=gs_checkpoint_path,\n",
    "    iteration_to_load=iteration_to_load,\n",
    "    load_gt_images=load_gt_images,\n",
    "    eval_split=use_eval_split,\n",
    "    eval_split_interval=n_skip_images_for_eval_split,\n",
    "    )\n",
    "\n",
    "print(f'{len(nerfmodel.training_cameras)} training images detected.')\n",
    "print(f'The model has been trained for {iteration_to_load} steps.')\n",
    "print(len(nerfmodel.gaussians._xyz) / 1e6, \"M gaussians detected.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Render with a refined SuGaR model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Choose a corresponding refined SuGaR checkpoint directory (located in `refined/<your scene>`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example\n",
    "refined_sugar_folder = \"./output/refined/qant03/sugarfine_3Dgs7000_densityestim02_sdfnorm02_level03_decim1000000_normalconsistency01_gaussperface1/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Choose a refinement iteration to load."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "refined_iteration_to_load = 15_000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the refined SuGaR checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "refined_sugar_path = os.path.join(refined_sugar_folder, f\"{refined_iteration_to_load}.pt\")\n",
    "print(f\"\\nLoading config {refined_sugar_path}...\")\n",
    "\n",
    "refined_sugar = load_refined_model(refined_sugar_path, nerfmodel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Render an image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "onsamerow = True\n",
    "also_render_vanilla_3dgs = False\n",
    "\n",
    "# -----Camera to render-----\n",
    "cameras_to_use = nerfmodel.training_cameras\n",
    "# cameras_to_use = nerfmodel.test_cameras\n",
    "\n",
    "cam_idx = np.random.randint(0, len(cameras_to_use.gs_cameras))\n",
    "# --------------------------\n",
    "\n",
    "refined_sugar.eval()\n",
    "refined_sugar.adapt_to_cameras(cameras_to_use)\n",
    "\n",
    "print(f\"Rendering image with index {cam_idx}.\")\n",
    "print(\"Image name:\", cameras_to_use.gs_cameras[cam_idx].image_name)\n",
    "\n",
    "verbose = False\n",
    "normalize_img = False\n",
    "\n",
    "if load_gt_images:\n",
    "    gt_rgb = nerfmodel.get_gt_image(cam_idx)\n",
    "    i_sugar = 2\n",
    "else:\n",
    "    i_sugar = 1\n",
    "\n",
    "with torch.no_grad():\n",
    "    if also_render_vanilla_3dgs:\n",
    "        gs_image = nerfmodel.render_image(\n",
    "            nerf_cameras=cameras_to_use,\n",
    "            camera_indices=cam_idx).clamp(min=0, max=1)\n",
    "    \n",
    "    sugar_image = refined_sugar.render_image_gaussian_rasterizer(\n",
    "        nerf_cameras=cameras_to_use, \n",
    "        camera_indices=cam_idx,\n",
    "        # bg_color=1. * torch.Tensor([1.0, 1.0, 1.0]).to(rc_fine.device),\n",
    "        sh_deg=nerfmodel.gaussians.active_sh_degree,\n",
    "        compute_color_in_rasterizer=True,\n",
    "    ).nan_to_num().clamp(min=0, max=1)\n",
    "\n",
    "# Change this to adjust the size of the plot\n",
    "plot_ratio = 2. # 0.7, 1.5, 5\n",
    "\n",
    "if also_render_vanilla_3dgs:\n",
    "    plt.figure(figsize=(10 * plot_ratio, 10 * plot_ratio))\n",
    "    plt.axis(\"off\")\n",
    "    plt.title(\"Vanilla 3DGS render\")\n",
    "    plt.imshow(gs_image.cpu().numpy())\n",
    "    plt.show()\n",
    "plt.figure(figsize=(10 * plot_ratio, 10 * plot_ratio))\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Refined SuGaR render\")\n",
    "plt.imshow(sugar_image.cpu().numpy())\n",
    "plt.show()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Render with a traditional color texture for SuGaR mesh"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Select the path to the textured mesh (i.e. the obj file in `refined_mesh/<your scene>`).<br>\n",
    "If None, the path to the mesh will be automatically computed from the checkpoint path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "refined_mesh_path = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load mesh."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if refined_mesh_path is None:\n",
    "    post_processed = False\n",
    "\n",
    "    if post_processed:\n",
    "        post_processed_str = '_postprocessed'\n",
    "    else:\n",
    "        post_processed_str = ''\n",
    "\n",
    "    scene_name = refined_sugar_path.split('/')[-3]\n",
    "    refined_mesh_dir = './output/refined_mesh'\n",
    "    refined_mesh_path = os.path.join(\n",
    "        refined_mesh_dir, scene_name,\n",
    "        refined_sugar_path.split('/')[-2].split('.')[0] + '.obj'\n",
    "    )\n",
    "    \n",
    "print(f\"Loading refined mesh from {refined_mesh_path}, this could take a minute...\")\n",
    "textured_mesh = load_objs_as_meshes([refined_mesh_path]).to(nerfmodel.device)\n",
    "print(f\"Loaded textured mesh with {len(textured_mesh.verts_list()[0])} vertices and {len(textured_mesh.faces_list()[0])} faces.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -----Camera index to render-----\n",
    "cam_idx = np.random.randint(0, len(cameras_to_use))\n",
    "# --------------------------\n",
    "\n",
    "faces_per_pixel = 1\n",
    "max_faces_per_bin = 50_000\n",
    "\n",
    "mesh_raster_settings = RasterizationSettings(\n",
    "    image_size=(refined_sugar.image_height, refined_sugar.image_width),\n",
    "    blur_radius=0.0, \n",
    "    faces_per_pixel=faces_per_pixel,\n",
    "    # max_faces_per_bin=max_faces_per_bin\n",
    ")\n",
    "lights = AmbientLights(device=nerfmodel.device)\n",
    "rasterizer = MeshRasterizer(\n",
    "        cameras=cameras_to_use.p3d_cameras[cam_idx], \n",
    "        raster_settings=mesh_raster_settings,\n",
    "    )\n",
    "renderer = MeshRenderer(\n",
    "    rasterizer=rasterizer,\n",
    "    shader=SoftPhongShader(\n",
    "        device=refined_sugar.device, \n",
    "        cameras=cameras_to_use.p3d_cameras[cam_idx],\n",
    "        lights=lights,\n",
    "        # blend_params=BlendParams(background_color=(0.0, 0.0, 0.0)),\n",
    "        blend_params=BlendParams(background_color=(1.0, 1.0, 1.0)),\n",
    "    )\n",
    ")\n",
    "\n",
    "with torch.no_grad():    \n",
    "    print(\"Rendering image\", cam_idx)\n",
    "    print(\"Image ID:\", cameras_to_use.gs_cameras[cam_idx].image_name)\n",
    "    \n",
    "    p3d_cameras = cameras_to_use.p3d_cameras[cam_idx]\n",
    "    rgb_img = renderer(textured_mesh, cameras=p3d_cameras)[0, ..., :3]\n",
    "    \n",
    "# Change this to adjust the size of the plot\n",
    "plot_ratio = 2.\n",
    "\n",
    "plt.figure(figsize=(10 * plot_ratio, 10 * plot_ratio))\n",
    "plt.axis(\"off\")\n",
    "plt.title(\"Refined SuGaR mesh with a traditional color UV texture\")\n",
    "plt.imshow(rgb_img.cpu().numpy())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sugar",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
